// Human-readable summary of what the tests in this file cover.
export const description = `Test compute shader builtin variables`;

import { makeTestGroup } from '../../../../common/framework/test_group.js';
import { iterRange } from '../../../../common/util/util.js';
import { GPUTest } from '../../../gpu_test.js';

// Test group for this file, created from GPUTest; tests below are registered via `g.test(...)`.
export const g = makeTestGroup(GPUTest);

// Test that the values for each input builtin are correct.
//
// The shader writes every compute input builtin, for every invocation in the dispatch, into one
// storage buffer per builtin. The buffers are then compared against values computed on the CPU.
// Parameters:
//   - method:   how the builtins are declared on the entry point (plain parameters, a single
//               struct, or a mix of both).
//   - dispatch: whether the grid is launched with dispatch() or dispatchIndirect().
//   - groupSize / numGroups: workgroup dimensions and number of workgroups per dimension.
g.test('inputs')
  .desc(`Test compute shader builtin inputs values`)
  .params(u =>
    u
      .combine('method', ['param', 'struct', 'mixed'] as const)
      .combine('dispatch', ['direct', 'indirect'] as const)
      .combineWithParams([
        {
          groupSize: { x: 1, y: 1, z: 1 },
          numGroups: { x: 1, y: 1, z: 1 },
        },
        {
          groupSize: { x: 8, y: 4, z: 2 },
          numGroups: { x: 1, y: 1, z: 1 },
        },
        {
          groupSize: { x: 1, y: 1, z: 1 },
          numGroups: { x: 8, y: 4, z: 2 },
        },
        {
          groupSize: { x: 3, y: 7, z: 5 },
          numGroups: { x: 13, y: 9, z: 11 },
        },
      ] as const)
      .beginSubcases()
  )
  .fn(async t => {
    const invocationsPerGroup = t.params.groupSize.x * t.params.groupSize.y * t.params.groupSize.z;
    const totalInvocations =
      invocationsPerGroup * t.params.numGroups.x * t.params.numGroups.y * t.params.numGroups.z;

    // Generate the structures, parameters, and builtin expressions used in the shader.
    // `params` is spliced into the entry point's parameter list, `structures` is emitted at
    // module scope, and the remaining strings are the WGSL expressions that read each builtin.
    let params = '';
    let structures = '';
    let local_id = '';
    let local_index = '';
    let global_id = '';
    let group_id = '';
    let num_groups = '';
    switch (t.params.method) {
      case 'param':
        // Every builtin is its own entry point parameter.
        params = `
          [[builtin(local_invocation_id)]] local_id : vec3<u32>,
          [[builtin(local_invocation_index)]] local_index : u32,
          [[builtin(global_invocation_id)]] global_id : vec3<u32>,
          [[builtin(workgroup_id)]] group_id : vec3<u32>,
          [[builtin(num_workgroups)]] num_groups : vec3<u32>,
        `;
        local_id = 'local_id';
        local_index = 'local_index';
        global_id = 'global_id';
        group_id = 'group_id';
        num_groups = 'num_groups';
        break;
      case 'struct':
        // Every builtin is a member of a single struct parameter.
        structures = `struct Inputs {
            [[builtin(local_invocation_id)]] local_id : vec3<u32>;
            [[builtin(local_invocation_index)]] local_index : u32;
            [[builtin(global_invocation_id)]] global_id : vec3<u32>;
            [[builtin(workgroup_id)]] group_id : vec3<u32>;
            [[builtin(num_workgroups)]] num_groups : vec3<u32>;
          };`;
        params = `inputs : Inputs`;
        local_id = 'inputs.local_id';
        local_index = 'inputs.local_index';
        global_id = 'inputs.global_id';
        group_id = 'inputs.group_id';
        num_groups = 'inputs.num_groups';
        break;
      case 'mixed':
        // Builtins are split between two struct parameters and two plain parameters.
        structures = `struct InputsA {
          [[builtin(local_invocation_index)]] local_index : u32;
          [[builtin(global_invocation_id)]] global_id : vec3<u32>;
        };
        struct InputsB {
          [[builtin(workgroup_id)]] group_id : vec3<u32>;
        };`;
        params = `[[builtin(local_invocation_id)]] local_id : vec3<u32>,
                  inputsA : InputsA,
                  inputsB : InputsB,
                  [[builtin(num_workgroups)]] num_groups : vec3<u32>,`;
        local_id = 'local_id';
        local_index = 'inputsA.local_index';
        global_id = 'inputsA.global_id';
        group_id = 'inputsB.group_id';
        num_groups = 'num_groups';
        break;
    }

    // WGSL shader that stores every builtin value to a buffer, for every invocation in the grid.
    // Each invocation writes to slot `global_index`, which linearizes (workgroup, invocation)
    // the same way the CPU-side checks below do.
    const wgsl = `
      struct S {
        data : array<u32>;
      };
      struct V {
        data : array<vec3<u32>>;
      };
      [[group(0), binding(0)]] var<storage, write> local_id_out : V;
      [[group(0), binding(1)]] var<storage, write> local_index_out : S;
      [[group(0), binding(2)]] var<storage, write> global_id_out : V;
      [[group(0), binding(3)]] var<storage, write> group_id_out : V;
      [[group(0), binding(4)]] var<storage, write> num_groups_out : V;

      ${structures}

      let group_width = ${t.params.groupSize.x}u;
      let group_height = ${t.params.groupSize.y}u;
      let group_depth = ${t.params.groupSize.z}u;

      [[stage(compute), workgroup_size(group_width, group_height, group_depth)]]
      fn main(
        ${params}
        ) {
        let group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x;
        let global_index = group_index * ${invocationsPerGroup}u + ${local_index};
        local_id_out.data[global_index] = ${local_id};
        local_index_out.data[global_index] = ${local_index};
        global_id_out.data[global_index] = ${global_id};
        group_id_out.data[global_index] = ${group_id};
        num_groups_out.data[global_index] = ${num_groups};
      }
    `;

    const pipeline = t.device.createComputePipeline({
      compute: {
        module: t.device.createShaderModule({
          code: wgsl,
        }),
        entryPoint: 'main',
      },
    });

    // Bind group entries are collected by `createBuffer` as each output buffer is created.
    // Declared ahead of the helper so the helper never depends on a later declaration.
    const bindGroupEntries: GPUBindGroupEntry[] = [];

    // Helper to create a `size`-byte buffer with binding number `binding`.
    // The buffer is registered for cleanup and appended to `bindGroupEntries`.
    function createBuffer(size: number, binding: number) {
      const buffer = t.device.createBuffer({
        size,
        usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
      });
      t.trackForCleanup(buffer);

      bindGroupEntries.push({
        binding,
        resource: {
          buffer,
        },
      });

      return buffer;
    }

    // Create the output buffers. The vec3<u32> outputs are sized at 16 bytes (4 u32 slots) per
    // invocation, which is the stride the checks below use; the u32 output uses 4 bytes.
    const localIdBuffer = createBuffer(totalInvocations * 16, 0);
    const localIndexBuffer = createBuffer(totalInvocations * 4, 1);
    const globalIdBuffer = createBuffer(totalInvocations * 16, 2);
    const groupIdBuffer = createBuffer(totalInvocations * 16, 3);
    const numGroupsBuffer = createBuffer(totalInvocations * 16, 4);

    const bindGroup = t.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: bindGroupEntries,
    });

    // Run the shader.
    const encoder = t.device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    switch (t.params.dispatch) {
      case 'direct':
        pass.dispatch(t.params.numGroups.x, t.params.numGroups.y, t.params.numGroups.z);
        break;
      case 'indirect': {
        // The indirect arguments are three consecutive u32 values: the x, y, z group counts.
        const dispatchBuffer = t.device.createBuffer({
          size: 3 * Uint32Array.BYTES_PER_ELEMENT,
          usage: GPUBufferUsage.INDIRECT,
          mappedAtCreation: true,
        });
        t.trackForCleanup(dispatchBuffer);
        const dispatchData = new Uint32Array(dispatchBuffer.getMappedRange());
        dispatchData[0] = t.params.numGroups.x;
        dispatchData[1] = t.params.numGroups.y;
        dispatchData[2] = t.params.numGroups.z;
        dispatchBuffer.unmap();
        pass.dispatchIndirect(dispatchBuffer, 0);
        break;
      }
    }
    pass.endPass();
    t.queue.submit([encoder.finish()]);

    type vec3 = { x: number; y: number; z: number };

    // Component names in the order they are laid out within each 4-u32 output slot.
    const vec3Components = ['x', 'y', 'z'] as const;

    // Helper to check that the vec3<u32> value at each index of the provided `output` buffer
    // matches the expected value for that invocation, as generated by the `getBuiltinValue`
    // function. The `name` parameter is the builtin name, used for error messages.
    // Returns an Error describing the first mismatch found, or undefined if every value matches.
    const checkEachIndex = (
      output: Uint32Array,
      name: string,
      getBuiltinValue: (groupId: vec3, localId: vec3) => vec3
    ) => {
      // Loop over workgroups.
      for (let gz = 0; gz < t.params.numGroups.z; gz++) {
        for (let gy = 0; gy < t.params.numGroups.y; gy++) {
          for (let gx = 0; gx < t.params.numGroups.x; gx++) {
            // The workgroup's linear index does not depend on the invocation, so compute it once.
            const groupIndex = (gz * t.params.numGroups.y + gy) * t.params.numGroups.x + gx;
            // Loop over invocations within a group.
            for (let lz = 0; lz < t.params.groupSize.z; lz++) {
              for (let ly = 0; ly < t.params.groupSize.y; ly++) {
                for (let lx = 0; lx < t.params.groupSize.x; lx++) {
                  const localIndex = (lz * t.params.groupSize.y + ly) * t.params.groupSize.x + lx;
                  const globalIndex = groupIndex * invocationsPerGroup + localIndex;
                  const expected = getBuiltinValue(
                    { x: gx, y: gy, z: gz },
                    { x: lx, y: ly, z: lz }
                  );
                  // Compare x, y, z in order; the fourth u32 of each slot is not inspected.
                  for (const [offset, component] of vec3Components.entries()) {
                    const got = output[globalIndex * 4 + offset];
                    if (got !== expected[component]) {
                      return new Error(
                        `${name}.${component} failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz})\n` +
                          `    expected: ${expected[component]}\n` +
                          `    got:      ${got}`
                      );
                    }
                  }
                }
              }
            }
          }
        }
      }
      return undefined;
    };

    // Check [[builtin(local_invocation_index)]] values.
    // Slots are grouped by workgroup, so the expected index cycles through 0..invocationsPerGroup-1.
    t.expectGPUBufferValuesEqual(
      localIndexBuffer,
      new Uint32Array([...iterRange(totalInvocations, x => x % invocationsPerGroup)])
    );

    // Check [[builtin(local_invocation_id)]] values.
    t.expectGPUBufferValuesPassCheck(
      localIdBuffer,
      outputData => checkEachIndex(outputData, 'local_invocation_id', (_, localId) => localId),
      { type: Uint32Array, typedLength: totalInvocations * 4 }
    );

    // Check [[builtin(global_invocation_id)]] values.
    // Expected value per component: workgroup id * workgroup size + local invocation id.
    const getGlobalId = (groupId: vec3, localId: vec3) => {
      return {
        x: groupId.x * t.params.groupSize.x + localId.x,
        y: groupId.y * t.params.groupSize.y + localId.y,
        z: groupId.z * t.params.groupSize.z + localId.z,
      };
    };
    t.expectGPUBufferValuesPassCheck(
      globalIdBuffer,
      outputData => checkEachIndex(outputData, 'global_invocation_id', getGlobalId),
      { type: Uint32Array, typedLength: totalInvocations * 4 }
    );

    // Check [[builtin(workgroup_id)]] values.
    t.expectGPUBufferValuesPassCheck(
      groupIdBuffer,
      outputData => checkEachIndex(outputData, 'workgroup_id', (groupId, _) => groupId),
      { type: Uint32Array, typedLength: totalInvocations * 4 }
    );

    // Check [[builtin(num_workgroups)]] values.
    // Every invocation is expected to observe the same dispatch size.
    t.expectGPUBufferValuesPassCheck(
      numGroupsBuffer,
      outputData => checkEachIndex(outputData, 'num_workgroups', () => t.params.numGroups),
      { type: Uint32Array, typedLength: totalInvocations * 4 }
    );
  });
